# improved_walker2d.py
import gym
import numpy as np
from gym import spaces
from gym.envs.registration import register


class ImprovedWalker2dEnv(gym.Env):
    """
    A lightweight, MuJoCo-free Walker2d-style environment.

    - 2D walker with a torso and two legs (hip/knee/ankle each -> 6 joints)
    - Continuous action space: torques for 6 joints in [-1, 1]
    - Observation includes positions, angles, velocities and foot contacts
    - Optional noise on actions, dynamics and observations (for robustness)

    State layout (size = 20):
      [0:2]     torso (x, z)
      [2]       torso angle (pitch)
      [3:9]     joint angles [Lhip, Lknee, Lankle, Rhip, Rknee, Rankle]
      [9:11]    torso linear vel (x_dot, z_dot)
      [11]      torso ang vel (pitch_dot)
      [12:18]   joint angular vels (same order as angles)
      [18:20]   foot contacts [Lfoot, Rfoot] in {0,1}

    Notes:
      * Observation noise (when ``obs_noise_scale > 0``) is added to every
        entry of the observation, including the two contact flags, so the
        observed contacts are then only approximately 0/1.
      * All randomness is drawn from NumPy's global ``np.random`` generator;
        ``get_state``/``set_state`` snapshot and restore that global state.
      * Observations are not clipped, and torso x and the velocities are not
        bounded by the dynamics, so the observation space is declared
        unbounded.
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        action_noise_scale: float = 0.03,
        dynamics_noise_scale: float = 0.02,
        obs_noise_scale: float = 0.01,
        max_steps: int = 1000,
    ):
        """Create the environment and reset it to an initial pose.

        Args:
            action_noise_scale: Std-dev of Gaussian noise added to each action
                component before clipping. Values <= 0 disable the noise.
            dynamics_noise_scale: Std-dev of Gaussian noise added to joint,
                torso linear and torso angular velocities on every step.
                Values <= 0 disable the noise.
            obs_noise_scale: Std-dev of Gaussian noise added to every
                observation entry. Values <= 0 disable the noise.
            max_steps: Number of steps after which ``step`` reports ``done``.
        """
        super().__init__()

        # --- Action/Observation spaces ---
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
        # Unbounded on purpose: _get_obs() never clips, and torso x as well as
        # the velocities can grow past any fixed limit during an episode, so a
        # finite box would be violated by the observations actually returned.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(20,), dtype=np.float32
        )

        # --- Noise params ---
        self.action_noise_scale = float(action_noise_scale)
        self.dynamics_noise_scale = float(dynamics_noise_scale)
        self.obs_noise_scale = float(obs_noise_scale)

        # --- Physics params (simple, stable) ---
        self.dt = 0.04
        self.mass = 3.0
        self.gravity = -9.8
        self.joint_damping = 0.12
        self.joint_coupling = 0.25
        self.ang_momentum_coupling = 0.15
        self.momentum = 0.9
        self.torso_width = 0.25
        self.leg_len = 0.9
        self.ground_z = 0.0

        # Contact/drag
        self.contact_push = 6.0
        self.air_drag = 0.98
        self.ground_drag = 0.65

        # Reward weights
        self.w_fwd = 2.25
        self.alive_bonus = 1.0
        self.h_bonus = 0.4
        self.w_energy = 0.03
        self.w_posture = 0.5
        self.stumble_w = 0.03

        # Termination thresholds
        # NOTE: min_height is used by step() as the floor the torso height is
        # clamped to, not as the height at which the episode terminates.
        self.min_height = 0.5
        self.max_torso_angle = np.deg2rad(60.0)

        # Episode bookkeeping
        self.max_steps = int(max_steps)

        # reset() assigns every piece of dynamic state (steps, torso pose and
        # velocities, joint angles/velocities, contacts, previous foot
        # positions), so nothing needs to be pre-initialised here.
        self.reset()

    # -----------------------------
    #  Helpers
    # -----------------------------
    def _foot_positions(self):
        """Approximate foot (x, z) from hip and joint angles.

        Each leg is treated as a single rigid segment of length ``leg_len``
        whose angle is a weighted sum of the hip/knee/ankle angles plus the
        torso pitch.

        Returns:
            ``float32`` array of shape (2, 2): ``[[Lx, Lz], [Rx, Rz]]``.
        """
        hip_left_x = self.torso_pos[0] - 0.5 * self.torso_width
        hip_right_x = self.torso_pos[0] + 0.5 * self.torso_width
        hip_z = self.torso_pos[1]

        L_eff = self.q[0] + 0.7 * self.q[1] + 0.3 * self.q[2] + self.torso_ang
        R_eff = self.q[3] + 0.7 * self.q[4] + 0.3 * self.q[5] + self.torso_ang

        Lx = hip_left_x + self.leg_len * np.sin(L_eff)
        Lz = hip_z - self.leg_len * np.cos(L_eff)
        Rx = hip_right_x + self.leg_len * np.sin(R_eff)
        Rz = hip_z - self.leg_len * np.cos(R_eff)
        return np.array([[Lx, Lz], [Rx, Rz]], dtype=np.float32)

    def _compute_contacts(self, feet):
        """Return per-foot contact flags: 1 where foot z <= ground, else 0.

        Args:
            feet: Array of foot positions as returned by ``_foot_positions``.

        Returns:
            ``int32`` array of shape (2,).
        """
        return (feet[:, 1] <= self.ground_z + 1e-6).astype(np.int32)

    def _get_obs(self):
        """Assemble the 20-dim observation, with optional Gaussian noise.

        Returns:
            ``float32`` array laid out as described in the class docstring.
        """
        obs = np.zeros(20, dtype=np.float32)
        obs[0:2] = self.torso_pos
        obs[2] = self.torso_ang
        obs[3:9] = self.q
        obs[9:11] = self.torso_vel
        obs[11] = self.torso_ang_vel
        obs[12:18] = self.qd
        obs[18:20] = self.contact

        if self.obs_noise_scale > 0:
            # Noise is applied to the whole vector, contact flags included.
            obs = obs + np.random.normal(0.0, self.obs_noise_scale, size=obs.shape).astype(np.float32)
        return obs

    @staticmethod
    def _restore_array(state, key, shape, dtype):
        """Build a fresh array from ``state[key]`` and verify its shape.

        Raises:
            KeyError: If ``key`` is missing from ``state``.
            ValueError: If the stored value does not have the expected shape.
        """
        arr = np.array(state[key], dtype=dtype)
        if arr.shape != shape:
            raise ValueError(
                f"state[{key!r}] has shape {arr.shape}, expected {shape}"
            )
        return arr

    def get_state(self):
        """Return a pure-Python/NumPy state dict safe to pickle.

        The dict holds the step counter, the full simulation state, the three
        noise scales and a snapshot of NumPy's global RNG state.
        """
        return {
            "steps": int(self.steps),
            # Convert to lists to avoid NumPy array copying issues
            "torso_pos": self.torso_pos.tolist(),
            "torso_vel": self.torso_vel.tolist(),
            "torso_ang": float(self.torso_ang),
            "torso_ang_vel": float(self.torso_ang_vel),
            "q": self.q.tolist(),
            "qd": self.qd.tolist(),
            "contact": self.contact.tolist(),
            "prev_foot_xy": self.prev_foot_xy.tolist(),
            "action_noise_scale": float(self.action_noise_scale),
            "dynamics_noise_scale": float(self.dynamics_noise_scale),
            "obs_noise_scale": float(self.obs_noise_scale),
            "rng_state": np.random.get_state(),
        }

    def set_state(self, state):
        """Restore from a state dict produced by get_state().

        Noise scales and ``rng_state`` are optional in ``state``; every other
        key is required. If ``rng_state`` is present, NumPy's *global* RNG
        state is overwritten.

        Raises:
            RuntimeError: If a required key is missing or a value is malformed
                (wrong type or wrong shape). The environment is reset to a
                fresh initial state before the error is raised, and the
                original exception is attached as ``__cause__``.
        """
        try:
            self.steps = int(state["steps"])
            # Create fresh numpy arrays from lists to avoid compatibility
            # issues, and reject wrong shapes here rather than letting them
            # surface later as obscure broadcasting errors inside step().
            self.torso_pos = self._restore_array(state, "torso_pos", (2,), np.float32)
            self.torso_vel = self._restore_array(state, "torso_vel", (2,), np.float32)
            self.torso_ang = float(state["torso_ang"])
            self.torso_ang_vel = float(state["torso_ang_vel"])
            self.q = self._restore_array(state, "q", (6,), np.float32)
            self.qd = self._restore_array(state, "qd", (6,), np.float32)
            self.contact = self._restore_array(state, "contact", (2,), np.int32)
            self.prev_foot_xy = self._restore_array(state, "prev_foot_xy", (2, 2), np.float32)
            self.action_noise_scale = float(state.get("action_noise_scale", self.action_noise_scale))
            self.dynamics_noise_scale = float(state.get("dynamics_noise_scale", self.dynamics_noise_scale))
            self.obs_noise_scale = float(state.get("obs_noise_scale", self.obs_noise_scale))
            if "rng_state" in state:
                np.random.set_state(state["rng_state"])
        except Exception as e:
            print(f"Error in set_state: {e}")
            # Fallback: reset to a safe initial state so the env is never left
            # half-restored.
            self.reset()
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to restore state: {e}") from e

    # -----------------------------
    #  Gym API
    # -----------------------------
    def reset(self):
        """Start a new episode and return the first observation.

        The torso starts at (0, 1.2) with zero velocity; the torso pitch and
        the six joint angles are drawn uniformly from [-0.05, 0.05].
        """
        self.steps = 0
        self.torso_pos = np.array([0.0, 1.2], dtype=np.float32)
        self.torso_vel = np.zeros(2, dtype=np.float32)
        self.torso_ang = np.random.uniform(-0.05, 0.05)
        self.torso_ang_vel = 0.0
        self.q = np.random.uniform(-0.05, 0.05, size=6).astype(np.float32)
        self.qd = np.zeros(6, dtype=np.float32)

        feet = self._foot_positions()
        self.contact = self._compute_contacts(feet)
        self.prev_foot_xy = feet.copy()

        return self._get_obs()

    def step(self, action):
        """Advance the simulation by one time step of length ``dt``.

        Args:
            action: Six joint torques, any array-like with exactly 6 finite
                elements. Values (after optional noise) are clipped to
                [-1, 1].

        Returns:
            Tuple ``(obs, reward, done, info)``; ``info`` is an empty dict.

        Raises:
            ValueError: If ``action`` does not have 6 elements or contains
                NaN/inf. Nothing is mutated in that case.
        """
        # --- Action handling (validated before any state is touched) ---
        a = np.asarray(action, dtype=np.float32)
        if a.size != 6:
            raise ValueError(
                f"action must have exactly 6 elements, got shape {a.shape}"
            )
        a = a.reshape(6)
        # np.clip lets NaN through, and a single NaN torque would poison the
        # joint state for the rest of the episode, so reject it up front.
        if not np.all(np.isfinite(a)):
            raise ValueError("action contains non-finite values (NaN or inf)")

        self.steps += 1

        if self.action_noise_scale > 0:
            noise = np.random.normal(0.0, self.action_noise_scale, size=a.shape).astype(np.float32)
            a = a + noise  # work on a fresh buffer

        # np.clip returns a new array, so the caller's action is never modified.
        a = np.clip(a, -1.0, 1.0)

        # --- Joint torque & coupling -> angular acceleration proxy ---
        tau = 3.0 * a
        tau_c = tau.copy()
        for base in (0, 3):
            # Within each leg, neighbouring joints leak a fraction of their
            # torque into each other (hip<->knee, knee<->ankle).
            tau_c[base + 0] += self.joint_coupling * tau[base + 1]
            tau_c[base + 1] += self.joint_coupling * (tau[base + 0] + tau[base + 2])
            tau_c[base + 2] += self.joint_coupling * tau[base + 1]

        # Angular velocity update with damping
        self.qd += (tau_c - self.joint_damping * self.qd) * self.dt

        # Optional dynamics noise
        if self.dynamics_noise_scale > 0:
            self.qd += np.random.normal(0.0, self.dynamics_noise_scale, size=self.qd.shape).astype(np.float32)
            self.torso_vel += np.random.normal(0.0, self.dynamics_noise_scale, size=self.torso_vel.shape).astype(np.float32)
            self.torso_ang_vel += float(np.random.normal(0.0, self.dynamics_noise_scale))

        # Integrate joint angles with bounds
        self.q += self.qd * self.dt
        self.q = np.clip(self.q, -1.2, 1.2)

        # --- Foot contacts & ground interaction ---
        feet = self._foot_positions()
        self.contact = self._compute_contacts(feet)

        push_forward = 0.0
        push_up = 0.0
        for leg_i in range(2):
            if self.contact[leg_i] == 1:
                # A grounded foot resists downward torso motion...
                if self.torso_vel[1] < 0:
                    push_up += -self.torso_vel[1] * self.mass * 0.5
                # ...and converts joint motion of that leg into forward push.
                leg_ang_speed = np.abs(self.qd[leg_i * 3:(leg_i + 1) * 3]).sum()
                push_forward += self.contact_push * 0.2 * leg_ang_speed
                push_up += self.contact_push * 0.1

        # --- Torso dynamics ---
        fwd_drive = self.ang_momentum_coupling * np.abs(self.qd).sum() + push_forward
        up_drive = push_up + self.gravity

        self.torso_vel[0] = self.momentum * self.torso_vel[0] + fwd_drive * self.dt
        self.torso_vel[1] = self.momentum * self.torso_vel[1] + up_drive * self.dt

        drag = self.ground_drag if np.any(self.contact) else self.air_drag
        self.torso_vel *= drag

        self.torso_pos += self.torso_vel * self.dt
        if self.torso_pos[1] < self.min_height:
            # Torso floor: clamp the height and bounce with 30% restitution.
            self.torso_pos[1] = self.min_height
            if self.torso_vel[1] < 0:
                self.torso_vel[1] = -0.3 * self.torso_vel[1]

        # Weak PD-like term pulling the torso pitch back towards upright.
        balance_torque = -0.2 * self.torso_ang - 0.08 * self.torso_ang_vel
        self.torso_ang_vel += balance_torque * self.dt
        self.torso_ang += self.torso_ang_vel * self.dt
        self.torso_ang = np.clip(self.torso_ang, -np.pi, np.pi)

        # --- Reward ---
        forward_vel = float(self.torso_vel[0])
        height = float(self.torso_pos[1])
        energy = float(np.dot(a, a))
        posture_pen = float(self.torso_ang ** 2)

        # Slip penalty: horizontal travel of feet that were in contact at the
        # start of this step, measured against last step's foot positions.
        slip_pen = 0.0
        new_feet = self._foot_positions()
        for i in range(2):
            if self.contact[i] == 1:
                slip = float(np.abs(new_feet[i, 0] - self.prev_foot_xy[i, 0]))
                slip_pen += slip
        self.prev_foot_xy = new_feet.copy()

        reward = (
            self.w_fwd * forward_vel
            + self.alive_bonus
            + self.h_bonus * height
            - self.w_energy * energy
            - self.w_posture * posture_pen
            - self.stumble_w * slip_pen
        )

        # NOTE(review): the torso height is clamped to self.min_height above,
        # so with the default min_height of 0.5 the `height < 0.3` test cannot
        # become true. It is left unchanged because making it reachable would
        # alter episode lengths; confirm the intended fall condition.
        done = (
            (height < 0.3)
            or (np.abs(self.torso_ang) > self.max_torso_angle)
            or (self.steps >= self.max_steps)
        )

        info = {}
        return self._get_obs(), float(reward), bool(done), info

    def render(self, mode="human"):
        """Print a one-line text summary of the current state.

        Only ``mode="human"`` produces output; any other mode is ignored.
        """
        if mode == "human":
            feet = self._foot_positions()
            print(
                f"x={self.torso_pos[0]:.2f} z={self.torso_pos[1]:.2f} "
                f"vx={self.torso_vel[0]:.2f} vz={self.torso_vel[1]:.2f} "
                f"ang={np.rad2deg(self.torso_ang):.1f}° "
                f"Lc={self.contact[0]} Rc={self.contact[1]} "
                f"Lfoot=({feet[0,0]:.2f},{feet[0,1]:.2f}) "
                f"Rfoot=({feet[1,0]:.2f},{feet[1,1]:.2f})"
            )


# ---- Gym registration ----
# Runs at import time and makes the environment available by id. The entry
# point string names this module as ``improved_walker2d``, so the file must be
# importable under that name for the id to resolve.
# NOTE(review): max_episode_steps presumably adds gym's own step limit on top
# of the environment's internal ``max_steps`` (also 1000 by default) — confirm
# that both limits are intended.
register(
    "ImprovedWalker2d-v0",
    entry_point="improved_walker2d:ImprovedWalker2dEnv",
    max_episode_steps=1000,
)
